import torch
import numpy as np
import scipy.sparse as sp 
import struct
import open3d.ml.torch as ml3d
import imageio
import os
import argparse
import tqdm
import glob
import open3d as o3d
import higra as hg

from tree import PartTree, save_tree, load_tree
from graph import get_tree

parser = argparse.ArgumentParser(
    description='Build a 3D part-hierarchy tree from rendered NeRF surface points and features.')

# Positional arguments.
parser.add_argument('exp', type=str,
                    help='experiment name; inputs are read from results/nerf/<exp>/<latest epoch>/')
parser.add_argument('output', type=str,
                    help='output name; results are written to vis/<output>.ply and vis/<output>/')

# Options (each takes a value).
parser.add_argument('-t', '--threshold', type=float, default=5e-3,
                    help='threshold value (not read anywhere in this script)')
parser.add_argument('--k_query', type=int, default=5,
                    help='number of nearest 3D points whose majority label is given to each query pixel')
parser.add_argument('-o', '--n_outliers', type=int, default=2,
                    help='outlier count (not read anywhere in this script)')
args = parser.parse_args()


def get_segmentation(feat, k_index, distance, num_keep):
    num_points, k = k_index.shape
    rows = torch.arange(num_points).repeat(k)
    cols = k_index.transpose(0, 1).reshape(-1)

    # Batch this to avoid memory overflow
    data = feat[rows]
    cdata = feat[cols]
    BATCH_SIZE = 1000
    for i in range(0, len(rows), BATCH_SIZE):
        data[i:i+BATCH_SIZE] = data[i:i+BATCH_SIZE] - cdata[i:i+BATCH_SIZE]
    data = (((data) ** 2).sum(-1) + 1e-8).sqrt()

    # TODO: Construct higra tree
    # tree, altitudes = hg.bpt_canonical((srcs, tgts.cpu().numpy(), len(feat)), graph_edge_lengths.detach().cpu().numpy().astype(float))
    rows = rows[data < distance]
    cols = cols[data < distance]
    data = data[data < distance]

    graph = sp.csr_matrix((data.cpu().numpy(), (rows.cpu().numpy(), cols.cpu().numpy())), shape=(len(feat), len(feat)))
    num_components, segmentation = sp.csgraph.connected_components(graph, connection='weak')
    
    count = np.histogram(segmentation, bins=[i for i in range(segmentation.max() + 2)])[0]
    segmentation = count.argsort()[::-1].argsort()[segmentation]
    segmentation[segmentation > num_keep] = num_keep # keep the num_keep largest masks
    
    return segmentation

def write_pointcloud(filename, xyz_points, rgb_points=None):
    """Write a colored point cloud to a binary little-endian PLY file.

    Points whose RGB values sum to zero (pure black) are skipped. The header's
    vertex count only counts the points actually written. The number of written
    points is also printed.

    Args:
        filename: Path of the .ply file to create (overwritten if it exists).
        xyz_points: (N, 3) array of coordinates; stored as float32.
        rgb_points: Optional (N, 3) array of colors, cast to uint8 on write.
            Defaults to white (255, 255, 255) for every point.

    Raises:
        ValueError: If ``xyz_points`` is not (N, 3), or ``rgb_points`` does not
            have the same shape.
    """
    xyz_points = np.asarray(xyz_points)
    if xyz_points.ndim != 2 or xyz_points.shape[1] != 3:
        raise ValueError('Input XYZ points should be Nx3 float array')
    if rgb_points is None:
        rgb_points = np.full(xyz_points.shape, 255, dtype=np.uint8)
    rgb_points = np.asarray(rgb_points)
    if xyz_points.shape != rgb_points.shape:
        raise ValueError('Input RGB colors should be Nx3 float array and have same size as input XYZ points')

    # Only non-black points are written.
    visible = rgb_points.sum(1) > 0
    num_visible = int(visible.sum())
    print(num_visible)

    # Packed (unaligned) little-endian records matching the header below:
    # 3 x float32 + 3 x uint8 = 15 bytes per vertex, independent of host byte order.
    vertices = np.empty(num_visible, dtype=[
        ('x', '<f4'), ('y', '<f4'), ('z', '<f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1'),
    ])
    for col, name in enumerate(('x', 'y', 'z')):
        vertices[name] = xyz_points[visible, col]
    for col, name in enumerate(('red', 'green', 'blue')):
        vertices[name] = rgb_points[visible, col]

    header = (
        'ply\n'
        'format binary_little_endian 1.0\n'
        f'element vertex {num_visible}\n'
        'property float x\n'
        'property float y\n'
        'property float z\n'
        'property uchar red\n'
        'property uchar green\n'
        'property uchar blue\n'
        'end_header\n'
    )
    with open(filename, 'wb') as fid:
        fid.write(header.encode('utf-8'))
        fid.write(vertices.tobytes())

print('Start 3D segmentation')
MIN_DEPTH = 0.3  # training pixels with rendered depth <= this are dropped
exp_dir = f'results/nerf/{args.exp}'

# Result sub-directories are named by epoch number; use the latest one.
# Non-numeric entries are ignored instead of crashing int().
epochs = [int(name) for name in os.listdir(exp_dir) if name.isdecimal()]
if not epochs:
    raise FileNotFoundError(f'No numeric epoch entries found in {exp_dir}')
last_epoch = max(epochs)
result_dir = f'{exp_dir}/{last_epoch}'

num_training_frames = len(glob.glob(f'{result_dir}/train_*_d.npy'))
num_test_frames = len(glob.glob(f'{result_dir}/[0-9]*_d.npy'))

# Training views: keep surface points and features only where depth > MIN_DEPTH.
coords = []
feats = []
for i in tqdm.tqdm(range(num_training_frames)):
    prefix = f'{result_dir}/train_{i:03d}'
    valid = np.load(f'{prefix}_d.npy').reshape(-1) > MIN_DEPTH
    coords.append(np.load(f'{prefix}_surface.npy').reshape(-1, 3)[valid])
    feats.append(torch.load(f'{prefix}_f.pth')[valid])

# Test (query) views: surface points, depths and rendered images.
print('Loading test data')
query_coords = []
query_depths = []
query_imgs = []
for i in range(num_test_frames):
    prefix = f'{result_dir}/{i:03d}'
    query_coords.append(np.load(f'{prefix}_surface.npy').reshape(-1, 3))
    query_depths.append(torch.tensor(np.load(f'{prefix}_d.npy')).reshape(-1))
    query_imgs.append(imageio.imread(f'{prefix}.png'))

if not coords:
    raise FileNotFoundError(f'No training frames (train_*_d.npy) found in {result_dir}')
coords = np.concatenate(coords)
feats = torch.cat(feats)

# Down-sample the merged cloud on a voxel grid, then drop isolated points.
VOXEL_SIZE = 2e-3
OUTLIER_RADIUS = 4e-3

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(coords)

# Element [2] of voxel_down_sample_and_trace is a list of original-point index
# lists; keep the first original point from each.
voxel_members = pcd.voxel_down_sample_and_trace(VOXEL_SIZE, coords.min(axis=0), coords.max(axis=0))[2]
downsample_ind = [members[0] for members in voxel_members]
print(f'Downsampling, {len(downsample_ind)}/{len(coords)} points left')

pcd_new = o3d.geometry.PointCloud()
pcd_new.points = o3d.utility.Vector3dVector(coords[downsample_ind])
pcd_new, ind = pcd_new.remove_radius_outlier(nb_points=1, radius=OUTLIER_RADIUS)
print(f'Cleaning, {len(ind)}/{len(downsample_ind)} points left')

# Compose both selections into one index into the original arrays, rather than
# re-gathering feats[downsample_ind][ind] several times.
keep_ind = np.asarray(downsample_ind, dtype=np.int64)[np.asarray(ind, dtype=np.int64)]
feats = feats[torch.from_numpy(keep_ind)]
coords = coords[keep_ind]

# Project features onto their top-3 principal directions and rescale each
# channel to [0, 1] so they can be saved as RGB colors.
feats_f = feats.float()
_, _, V = torch.pca_lowrank(feats_f, niter=10)
proj_V = V[:, :3].float()
lowrank = feats_f @ proj_V
low = lowrank.min(0, keepdim=True)[0]
# Clamp the range so a constant channel maps to 0 instead of NaN.
span = (lowrank.max(0, keepdim=True)[0] - low).clamp_min(1e-12)
lowrank = ((lowrank - low) / span).clip(0, 1)

os.makedirs('vis/', exist_ok=True)
write_pointcloud(f'vis/{args.output}.ply', coords, (lowrank.numpy() * 255).astype(np.uint8))
# Graph / query hyper-parameters.
k_graph = 17  # neighbours requested per point, including column 0 dropped below
num_seg = 200
k_query = args.k_query

# Cleaned coordinates as a tensor for the Open3D-ML KNN layers.
points = torch.tensor(coords)

nsearch = ml3d.layers.KNNSearch(return_distances=False)
nsearch_w_distance = ml3d.layers.KNNSearch(return_distances=True)

# kNN graph over the cleaned points. Column 0 is discarded, presumably because
# it is each point itself (distance 0) -- TODO confirm for duplicate points.
ans = nsearch(points, points, k_graph)
all_neighbors = ans.neighbors_index.reshape(-1, k_graph)
k_index = all_neighbors[:, 1:].long()

os.makedirs(f'vis/{args.output}/', exist_ok=True)


# Create part hierarchy tree.
# The feature hierarchy (higra tree over the 3D points) is cut at decreasing
# altitudes. Every large-enough segment that appears for the first time is
# rendered in query view IDX and attached under its closest ancestor already in
# the PartTree.
MIN_POINTS = 1000  # heuristic: segments covering fewer 3D points are ignored
IDX = 0
query_img = query_imgs[IDX]
query_coord = query_coords[IDX]
query_depth = query_depths[IDX]
output_dir = f'vis/{args.output}/{IDX}/tree'
seg_dir = f'vis/{args.output}/{IDX}'

tree, altitudes = get_tree(feats, k_index)
num_nodes = tree.root() + 1
root_image_path = f'results/nerf/{args.exp}/{last_epoch}/{IDX:03d}.png'
parttree = PartTree(output_dir, tree.root(), imageio.imread(root_image_path), root_image_path)

# Cuts could alternatively be enumerated by index via hce.num_cuts() /
# hce.horizontal_cut_from_index(i) instead of by altitude.
hce = hg.HorizontalCutExplorer(tree, altitudes)

# One random color per tree node, plus two extra rows set to black and white.
colors = np.random.randint(0, 256, (num_nodes + 2, 3))
colors[-2] = 0
colors[-1] = 255

# The k_query nearest 3D points of each query pixel do not depend on the cut,
# so search once rather than once per altitude.
os.makedirs(seg_dir, exist_ok=True)
nn_ans = nsearch_w_distance(points, torch.tensor(query_coord), k_query)
nn_index = nn_ans.neighbors_index.reshape(-1, k_query).long().numpy()
# Label maps must match the image's H x W, because they are used as masks on it below.
image_shape = query_img.shape[:2]

# NOTE(review): this sweeps altitudes 0.51 down to 0.02 in 0.01 steps; confirm
# that 0.50 -> 0.01 was not the intended range.
for distance in tqdm.tqdm([(i + 1) * 0.01 for i in range(50, 0, -1)]):
    cut = hce.horizontal_cut_from_altitude(distance)
    # One label per 3D point (tree leaf). The labels are used as tree node ids
    # by the tree.parent() walk below.
    seg_result = cut.labelisation_leaves(tree)

    # Give each query pixel the most common label among its k_query 3D neighbours.
    segmentation_2d = torch.tensor(seg_result[nn_index]).mode(dim=1)[0]
    segmentation_2d = segmentation_2d.reshape(image_shape).numpy()
    np.save(f'{seg_dir}/{distance:.03f}.npy', segmentation_2d)
    imageio.imsave(f'{seg_dir}/{distance:.03f}.png', colors[segmentation_2d].astype(np.uint8))

    labels, counts = np.unique(seg_result, return_counts=True)

    # TODO: for labels that are too small. Do KNN to assign label
    unseen = [(label, count) for label, count in zip(labels, counts) if not parttree.exists(label)]
    for new_label, point_count in unseen:  # point_count counts 3D points, not pixels
        if point_count < MIN_POINTS:
            continue
        # Walk up the feature hierarchy to the closest ancestor in the part tree.
        parent_label = tree.parent(new_label)
        while not parttree.exists(parent_label):
            grandparent = tree.parent(parent_label)
            if grandparent == parent_label:
                # The higra root is its own parent: stop instead of looping forever.
                raise RuntimeError(f'No ancestor of node {new_label} is registered in the part tree')
            parent_label = grandparent
        mask = segmentation_2d == new_label

        # Render the segment (all other pixels black) and add it to the part tree.
        seg_img = np.zeros_like(query_img)
        seg_img[mask] = query_img[mask]

        image_path = f'{seg_dir}/{distance:.03f}/{new_label:03d}.png'
        parttree.add_edge(new_label, seg_img, image_path, parent_label)

parttree.canonize_tree()
# parttree.query_preprocess()
parttree.render_tree(f'{seg_dir}/tree/tree_with_images')

save_tree(parttree, f'{seg_dir}/tree/tree.pkl')
